from plib.commands import Command
from plib.errors import Status

from serial import Serial, SerialException
from struct import pack, unpack

from binascii import hexlify
import time

# Number of payload bytes handed to the channel per write when streaming an
# image or DA payload (Device.upload_da / Device.send_image slice by this).
BLOCK_SIZE = 0x200
# Base address of the crypto engine; the Device.gcpu_* and Device.aes_*
# helpers read32/write32 registers at offsets from it.
CRYPTO_BASE = 0x10210000


class Channel:
    """Abstract byte-stream transport.

    Concrete transports (e.g. ``SerialPortChannel``) override every method.
    The base implementations raise ``NotImplementedError``, which is a
    ``RuntimeError`` subclass, so existing ``except RuntimeError`` handlers
    still catch it.
    """

    def write(self, data: bytes):
        """Send *data* to the peer."""
        raise NotImplementedError("Unimplemented")

    def read(self, size=1):
        """Receive up to *size* bytes from the peer."""
        raise NotImplementedError("Unimplemented")

    def write_read(self, data: bytes, size=1):
        """Send *data*, then receive up to *size* bytes of reply."""
        raise NotImplementedError("Unimplemented")

    def flush_input(self):
        """Discard any input already buffered from the peer."""
        raise NotImplementedError("Unimplemented")


class SerialPortChannel(Channel):
    """``Channel`` implementation on top of a pyserial ``Serial`` port."""

    def __init__(self, path: str, baud=115200, wait: bool = False,
                 timeout=None, retry_interval: float = 0.1):
        """Open the serial port at *path*.

        Args:
            path: serial device path.
            baud: baud rate passed to ``Serial``.
            wait: if false, a ``SerialException`` from opening the port is
                re-raised immediately; if true, opening is retried until it
                succeeds and is then followed by a 0.5 s pause.
            timeout: read timeout in seconds passed to ``Serial``. ``None``
                (the pyserial default, and the previous behaviour) makes
                reads block until the requested bytes arrive.
            retry_interval: delay in seconds between open attempts while
                *wait* is true.

        Raises:
            SerialException: if the port cannot be opened and *wait* is false.
        """
        while True:
            try:
                self.ch = Serial(path, baud, timeout=timeout)
            except SerialException:
                if not wait:
                    raise
                # Back off between attempts instead of spinning a CPU core
                # while the port is not available yet.
                time.sleep(retry_interval)
                continue
            if wait:
                time.sleep(0.5)
            return

    def write(self, data: bytes):
        """Write *data* to the port."""
        self.ch.write(data)

    def read(self, size=1):
        """Read up to *size* bytes (fewer only if a read timeout expires)."""
        return self.ch.read(size)

    def write_read(self, data: bytes, size=1):
        """Write *data*, then read up to *size* bytes of reply."""
        self.write(data)
        return self.read(size)

    def flush_input(self):
        """Discard any bytes already waiting in the input buffer."""
        # reset_input_buffer() is the pyserial 3.x API; flushInput() is its
        # deprecated alias.
        self.ch.reset_input_buffer()


class Error(Exception):
    """Raised when the device returns an unexpected status code for a command.

    Attributes:
        command: the ``Command`` that failed.
        status: a ``Status`` member when ``Status(raw_code)`` accepts the raw
            code, otherwise the raw integer code itself.
    """

    def __init__(self, command: Command, status):
        self.command = command
        try:
            self.status = Status(status)
        except ValueError:
            self.status = status

        if isinstance(self.status, Status):
            message = "%s: %s" % (self.command.name, self.status.name)
        else:
            message = "%s: %#04x" % (self.command.name, self.status)

        # Pass the message to Exception so str(exc) / exc.args are meaningful
        # (previously they were empty).
        super().__init__(message)
        # The console report emitted on construction is preserved.
        print("ERROR " + message)


class HandshakeError(Exception):
    """Raised when the device answers a handshake byte with an unexpected reply."""


class ProtocolError(Exception):
    """Raised when the device's echo of transmitted bytes does not match them."""


class Device:
    """Driver for a device reached over a ``SerialPortChannel``.

    Provides the start-up handshake, the echo/status command protocol
    (``execute_command``, ``read32``, ``write32``), payload upload and boot
    helpers, and crypto-engine (GCPU) routines used by ``unlock`` and
    ``dump_brom``.

    Attributes:
        ch: the underlying ``SerialPortChannel``.
        handshake_done: set to True once ``handshake`` has succeeded.
        brom_mode: selects the status value ``write32`` expects (1 when true,
            0 otherwise) and which branch ``unlock`` takes.
    """

    def __init__(self, path, wait: bool, brom_mode: bool = False):
        """Open the serial port at *path*; *wait* is forwarded to SerialPortChannel."""
        self.ch = SerialPortChannel(path, wait=wait)
        self.handshake_done = False
        self.brom_mode = brom_mode

    # ------------------------------------------------------------------
    # Low-level protocol helpers
    # ------------------------------------------------------------------

    def _read_u16(self):
        """Read one big-endian 16-bit word (a status code) from the channel."""
        return unpack(">H", self.ch.read(2))[0]

    def _send_echoed(self, data: bytes):
        """Write *data* and require the device to echo it back verbatim.

        The echo is read unconditionally rather than inside an ``assert``,
        so the byte stream stays in sync even under ``python -O``.

        Raises:
            ProtocolError: if the echoed bytes differ from *data*.
        """
        self.ch.write(data)
        echo = self.ch.read(len(data))
        if echo != data:
            raise ProtocolError("expected echo %r, got %r" % (data, echo))

    def _read_reg(self, addr):
        """Return the single 32-bit word at *addr* (via read32)."""
        return self.read32(addr)[0]

    def _send_blocks(self, payload: bytes):
        """Stream *payload* to the channel in BLOCK_SIZE chunks, one dot per chunk."""
        num_blocks = (len(payload) + BLOCK_SIZE - 1) // BLOCK_SIZE  # ceiling division
        print("Uploading %d blocks of size %d bytes " % (num_blocks, BLOCK_SIZE),
              end='', flush=True)
        # Slice by offset instead of repeatedly re-slicing the remaining
        # payload, which copied the whole tail on every iteration.
        for offset in range(0, len(payload), BLOCK_SIZE):
            print(".", end='', flush=True)
            self.ch.write(payload[offset:offset + BLOCK_SIZE])

    @staticmethod
    def _image_name_field(image_name: bytes):
        """Return *image_name* NUL-padded to the 64-byte name field.

        Raises:
            RuntimeError: if *image_name* is longer than 63 bytes.
        """
        if len(image_name) > 63:
            raise RuntimeError('image_name exceeds 63 bytes')
        return image_name.ljust(64, b'\x00')

    # ------------------------------------------------------------------
    # Session setup and generic commands
    # ------------------------------------------------------------------

    def meta_mode(self):
        """Request META mode by sending b'METAMETA' and report the 8-byte reply.

        Must be called before ``handshake``.
        """
        assert not self.handshake_done
        self.ch.write(b'METAMETA')
        response = self.ch.read(8)
        # NOTE(review): an empty reply is only possible when the serial port
        # was opened with a read timeout; with the default (None) read blocks.
        if len(response) == 0:
            print("Read timeout while switching to META mode")
        elif response in (b'ATEMATEM', b'READYATE'):
            print("META Mode enabled")
        elif response == b'METAFORB':
            print("META mode forbidden")
        else:
            print("Invalid response: ", response)

    def handshake(self):
        """Perform the start-up handshake, then try to disable the watchdog.

        Sends 0xA0 until 0x5F comes back (11 attempts at most, flushing input
        between attempts), then checks the fixed exchanges
        0x0A->0xF5, 0x50->0xAF and 0x05->0xFA.

        Raises:
            HandshakeError: if any step receives an unexpected reply.
        """
        attempt = 0
        while self.ch.write_read(b'\xa0') != b'\x5f':
            if attempt == 10:
                raise HandshakeError()
            self.ch.flush_input()
            attempt += 1

        for sent, expected in ((b'\x0a', b'\xf5'),
                               (b'\x50', b'\xaf'),
                               (b'\x05', b'\xfa')):
            if self.ch.write_read(sent) != expected:
                raise HandshakeError()

        self.handshake_done = True

        # Disable watchdog. Best effort: failure is reported, not raised.
        # Catch Exception (not a bare except) so Ctrl-C still interrupts.
        try:
            self.write32(0x10007000, 0x22000000)
        except Exception as exc:
            print('failed to disable watchdog: %s' % exc, flush=True)

    def execute_command(self, cmd: Command, response_size: int = 0, echo: bool = True, status: bool = True, data: bytes = b''):
        """Send ``cmd.value + data`` and optionally check status / read a response.

        Args:
            cmd: command whose ``value`` bytes are sent first.
            response_size: if > 0, read and return this many bytes at the end.
            echo: if true, resend until the device echoes the whole packet
                (11 attempts at most); if false, send it once without reading.
            status: if true, read a 16-bit status word and require it to be 0.
            data: argument bytes appended to the command.

        Returns:
            The response bytes when *response_size* > 0, else None.

        Raises:
            ProtocolError: if the echo never matches.
            Error: if the status word is non-zero.
        """
        packet = cmd.value + data
        if echo:
            attempt = 0
            while self.ch.write_read(packet, size=len(packet)) != packet:
                if attempt == 10:
                    raise ProtocolError("no valid echo for %s" % cmd.name)
                attempt += 1
        else:
            self.ch.write(packet)

        if status:
            code = self._read_u16()
            if code != 0:
                raise Error(cmd, code)

        if response_size > 0:
            return self.ch.read(response_size)
        return None

    def identify(self):
        """Query and print hardware code/version, security version and target config."""
        soc_id, soc_step = unpack(">HH", self.execute_command(
            Command.CMD_GET_HW_CODE, 4, status=False))
        ver, subver, extra = unpack(">HHI", self.execute_command(
            Command.CMD_GET_HW_VER, 8, status=False))
        sec_ver = self.execute_command(
            Command.CMD_GET_SEC_VERSION, 1, status=False, echo=False)
        tgt_cfg = unpack(">I", self.execute_command(
            Command.CMD_GET_TARGET_CONFIG, 4, status=False))[0]

        print("Chip: %x, stepping: %x" % (soc_id, soc_step))
        print("Hardware version: %#x, subversion: %#x, extra: %#x" %
              (ver, subver, extra))
        if sec_ver == b'\xff':
            print("Security: OFF")
        else:
            print("Security: ON (%#x)" % sec_ver[0])

        print("Target Config: %#x" % tgt_cfg)

        # Bits 0..2 of the target config map to SBC, SLA and DAA.
        for bit, label in ((1, "SBC"), (2, "SLA"), (4, "DAA")):
            print("%s: %s" % (label, "Enabled" if tgt_cfg & bit else "Disabled"))

    def read32(self, addr, size=1):
        """Read *size* consecutive 32-bit words starting at *addr*.

        Returns:
            A list of *size* ints.

        Raises:
            ProtocolError: if an echo does not match.
            Error: if either status word is non-zero.
        """
        self.ch.flush_input()

        self._send_echoed(Command.CMD_READ32.value)
        self._send_echoed(pack(">I", addr))
        self._send_echoed(pack(">I", size))

        status = self._read_u16()
        if status != 0:
            raise Error(Command.CMD_READ32, status)

        result = [unpack(">I", self.ch.read(4))[0] for _ in range(size)]

        status = self._read_u16()
        if status != 0:
            raise Error(Command.CMD_READ32, status)

        return result

    def write32(self, addr, words):
        """Write one or more 32-bit words starting at *addr*.

        Args:
            addr: target address; must be 4-byte aligned.
            words: a single int, or a list/tuple of ints.

        Raises:
            ValueError: if *addr* is unaligned or *words* is empty.
            ProtocolError: if an echo does not match.
            Error: if a status word differs from the expected value
                (1 when ``brom_mode`` is set, 0 otherwise).
        """
        if not isinstance(words, (list, tuple)):
            words = [words]

        # Real exceptions rather than asserts: asserts vanish under -O.
        if addr % 4 != 0:
            raise ValueError('addr must be 4-byte aligned: %#x' % addr)
        if not words:
            raise ValueError('words must not be empty')
        self.ch.flush_input()

        self._send_echoed(Command.CMD_WRITE32.value)
        self._send_echoed(pack(">I", addr))
        self._send_echoed(pack(">I", len(words)))

        expected = 1 if self.brom_mode else 0

        status = self._read_u16()
        if status != expected:
            raise Error(Command.CMD_WRITE32, status)

        for word in words:
            self._send_echoed(pack(">I", word))

        status = self._read_u16()
        if status != expected:
            raise Error(Command.CMD_WRITE32, status)

    def run_ext_cmd(self, cmd):
        """Send extended command byte *cmd* after the echoed 0xC8 prefix.

        The 1 + 2 bytes the device sends afterwards are read and discarded.
        """
        self._send_echoed(b'\xC8')
        self._send_echoed(bytes([cmd]))
        self.ch.read(1)
        self.ch.read(2)

    # ------------------------------------------------------------------
    # Payload upload / boot
    # ------------------------------------------------------------------

    def upload_da(self, base, payload: bytes, unk: int = 0x100):
        """Upload *payload* to *base* with CMD_SEND_DA, then CMD_JUMP_DA to it.

        After the jump, 25 bytes received from the device are printed in hex.

        Raises:
            Error: if the upload status or a command status is non-zero.
        """
        self.ch.flush_input()
        print("uploading payload ... ", flush=True)
        self.execute_command(Command.CMD_SEND_DA,
                             data=pack(">III", base, len(payload), unk))

        self._send_blocks(payload)

        # 2 bytes are discarded before the status word; presumably a
        # checksum of the payload -- TODO confirm.
        self.ch.read(2)
        status = self._read_u16()
        if status != 0:
            print(' ERROR', flush=True)
            raise Error(Command.CMD_SEND_DA, status)

        print(' DONE', flush=True)
        self.ch.flush_input()
        print('Booting ... ', end='', flush=True)
        self.execute_command(Command.CMD_JUMP_DA, data=pack(">I", base))
        print('DONE', flush=True)
        for _ in range(25):
            print('%x ' % self.ch.read(1)[0], end='', flush=True)
        print('')

    def send_image(self, image_name: bytes, payload: bytes):
        """Send *payload* as the image named *image_name* via CMD_SEND_IMAGE.

        The name is validated before anything is transmitted, so an invalid
        name no longer leaves the device half-way through the command.

        Raises:
            RuntimeError: if *image_name* is longer than 63 bytes.
        """
        name_field = self._image_name_field(image_name)
        self.execute_command(Command.CMD_SEND_IMAGE, status=False)
        self.ch.write(name_field)
        print('send image payload length = %d' % len(payload))
        self.ch.write(pack(">I", len(payload)))
        status = self._read_u16()
        print('send_image status = %#x' % status)
        if payload:
            self._send_blocks(payload)
            print(' DONE', flush=True)

        self.ch.write(pack(">I", 0))

    def boot_image(self, image_name: bytes):
        """Ask the device to boot the image named *image_name* (CMD_BOOT_IMAGE).

        Raises:
            RuntimeError: if *image_name* is longer than 63 bytes (checked
                before anything is transmitted).
        """
        name_field = self._image_name_field(image_name)
        self.execute_command(Command.CMD_BOOT_IMAGE, status=False)
        self.ch.write(name_field)
        status = self._read_u16()
        print('boot image status = %#x' % status)

    def read_file(self, p):
        """Return the full binary contents of the file at *p*."""
        with open(p, 'rb') as f:
            return f.read()

    def bootcda(self):
        """Send agent/build/agent.bin (prefixed by 512 zero bytes) as 'lk' and boot it."""
        agent = self.read_file('agent/build/agent.bin')
        self.send_image(b'lk', (b'\x00' * 512) + agent)
        self.boot_image(b'lk')

    def test2(self):
        """Debug helper: send SOFT/LK.IMG as 'lk', then boot 'tee1'.

        NOTE(review): the 'tee1' upload below is commented out, so this boots
        an image the method never sends, and ``atf`` is read but unused.
        """
        atf = self.read_file('SOFT/ATF1E.IMG')
        lk = self.read_file('SOFT/LK.IMG')
        self.send_image(b'lk', lk)
        # self.ch.flush_input()
        #self.send_image(b'tee1', (b'\x00' * 0xC40) + atf)
        # self.ch.flush_input()
        # self.identify()
        self.boot_image(b'tee1')

    # ------------------------------------------------------------------
    # Crypto engine (GCPU)
    # ------------------------------------------------------------------

    def gcpu_init(self):
        """Zero CRYPTO_BASE+0xC0C..0xC2C and the word slots at 0xC00 + 18/22/26 words."""
        for offset in range(0x0C0C, 0x0C30, 4):
            self.write32(CRYPTO_BASE + offset, 0)
        self._clear_gcpu_slots()

    def _clear_gcpu_slots(self):
        """Zero 4, 4 and 8 words at CRYPTO_BASE + 0xC00 + 18*4 / 22*4 / 26*4."""
        self.write32(CRYPTO_BASE + 0x0C00 + 18 * 4, [0] * 4)
        self.write32(CRYPTO_BASE + 0x0C00 + 22 * 4, [0] * 4)
        self.write32(CRYPTO_BASE + 0x0C00 + 26 * 4, [0] * 8)

    def gcpu_acquire(self):
        """Write 0x1F, 0x12000 to the first two words at CRYPTO_BASE."""
        self.write32(CRYPTO_BASE, [0x1F, 0x12000])

    def _gcpu_setup(self):
        """Run the gcpu_init / gcpu_acquire pair twice.

        The pair is deliberately issued twice; the reason is not documented
        here, so keep the sequence intact.
        """
        for _ in range(2):
            self.gcpu_init()
            self.gcpu_acquire()

    def gcpu_call_func(self, func):
        """Start GCPU function *func* and busy-wait for it to finish.

        Returns:
            -1 if bit 1 of the register at CRYPTO_BASE+0x800 is set once it
            becomes non-zero, otherwise 0 (after bit 0 of CRYPTO_BASE+0x418
            is set).

        NOTE(review): the polling loops have no timeout and will spin
        forever if the engine never reports completion.
        """
        status_reg = CRYPTO_BASE + 0x0800
        self.write32(CRYPTO_BASE + 0x0804, 3)
        self.write32(CRYPTO_BASE + 0x0808, 3)
        self.write32(CRYPTO_BASE + 0x0C00, func)
        self.write32(CRYPTO_BASE + 0x0400, 0)
        while not self._read_reg(status_reg):
            pass
        if self._read_reg(status_reg) & 2:
            if not self._read_reg(status_reg) & 1:
                while not self._read_reg(status_reg):
                    pass
            result = -1
        else:
            while not self._read_reg(CRYPTO_BASE + 0x0418) & 1:
                pass
            result = 0
        self.write32(CRYPTO_BASE + 0x0804, 3)
        return result

    def _gcpu_aes_decrypt(self, src, dst):
        """Program and run GCPU function 126 ("aes decrypt") from *src* to *dst*.

        Writes 1 to 0xC0C, 18 to 0xC14 and 26 to 0xC18/0xC1C; presumably
        these select the block count and the slot indices used by
        _clear_gcpu_slots -- TODO confirm.

        Raises:
            RuntimeError: if gcpu_call_func reports failure.
        """
        self.write32(CRYPTO_BASE + 0xC04, src)
        self.write32(CRYPTO_BASE + 0xC08, dst)
        self.write32(CRYPTO_BASE + 0xC0C, 1)
        self.write32(CRYPTO_BASE + 0xC14, 18)
        self.write32(CRYPTO_BASE + 0xC18, 26)
        self.write32(CRYPTO_BASE + 0xC1C, 26)
        if self.gcpu_call_func(126) != 0:  # aes decrypt
            raise RuntimeError("failed to call the function!")

    def aes_read16(self, addr):
        """Return 16 bytes read from *addr* through the GCPU.

        Runs the decrypt with *addr* as source and 0 as destination, then
        reads the 4 words at slot 26 (the IV) and packs them little-endian.

        Raises:
            RuntimeError: if the GCPU call fails (previously a bare Exception,
                which RuntimeError still satisfies).
        """
        self._gcpu_aes_decrypt(addr, 0)  # dst to invalid pointer
        words = self.read32(CRYPTO_BASE + 0xC00 + 26 * 4, 4)  # read out of the IV
        return pack("<4I", *words)

    def aes_write16(self, addr, data):
        """Write the 16 bytes *data* to *addr* through the GCPU.

        The IV slot is loaded with *data* XOR a fixed pattern, then a decrypt
        from source address 0 to *addr* is run. Presumably the pattern is
        the decryption of the all-zero block at address 0, so the output
        equals *data* -- see the "otherwise, update pattern" note below.

        Raises:
            RuntimeError: if *data* is not 16 bytes or the GCPU call fails.
        """
        if len(data) != 16:
            raise RuntimeError("data must be 16 bytes")

        pattern = bytes.fromhex("4dd12bdf0ec7d26c482490b3482a1b1f")

        # iv-xor
        words = [d ^ p for d, p in zip(unpack("<4I", data), unpack("<4I", pattern))]

        self._clear_gcpu_slots()
        self.write32(CRYPTO_BASE + 0xC00 + 26 * 4, words)

        # src to VALID address which has all zeroes (otherwise, update pattern)
        self._gcpu_aes_decrypt(0, addr)  # dst to our destination

    # ------------------------------------------------------------------
    # High-level operations
    # ------------------------------------------------------------------

    def unlock(self):
        """Run the unlock sequence for the current mode.

        With ``brom_mode``: disable caches (ext cmd 0xB1), set up the GCPU,
        patch 0x10276C via aes_write16, write two words at 0x102080, then
        write 0 to 0x1027AC to restart the boot ROM.

        Otherwise: send ``payload.bin`` (prefixed by 512 zero bytes) as 'lk',
        boot it, and echo the device's output forever.
        """
        if self.brom_mode:
            print('Disabling caches ... ', end='', flush=True)
            self.run_ext_cmd(0xB1)
            print('DONE', flush=True)

            self._gcpu_setup()
            print('GCPU init OK', flush=True)

            print('Disabling range checks ... ', end='', flush=True)
            self.aes_write16(0x10276C, bytes.fromhex(
                "00000000000000000000000080000000"))
            print('DONE', flush=True)

            # disable secure boot in preloader
            self.write32(0x102080, 0x3B6C243C)
            self.write32(0x00102084, 0xF843E0A)

            print('Restarting BOOTROM ...', flush=True)
            # jump
            self.write32(0x1027AC, 0)
        else:
            import codecs

            self.send_image(b'lk', (b'\x00' * 512) +
                            self.read_file('payload.bin'))
            self.boot_image(b'lk')
            print('waiting for payload to come online', flush=True)
            # Bytes arrive one at a time, so a multi-byte UTF-8 character is
            # split across reads; bytes.decode() on each byte would raise.
            # The incremental decoder buffers partial sequences and replaces
            # invalid bytes instead of crashing.
            decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
            while True:
                print(decoder.decode(self.ch.read()), flush=True, end='')

    def dump_brom(self):
        """Dump addresses 0..0x20000 (16 bytes at a time) to ``__BROM.BIN``."""
        self._gcpu_setup()
        self.run_ext_cmd(0xB1)
        with open('__BROM.BIN', 'wb') as f:
            for addr in range(0, 0x20000, 16):
                f.write(self.aes_read16(addr))
                print('.', end='', flush=True)

    def wdt_reboot(self):
        """Reprogram the watchdog at 0x10007000 and poke 0x10007014 to reboot.

        NOTE(review): bit 0x10 is cleared and then set again by the OR with
        0x22000014 below -- confirm that is intended.
        """
        self.write32(0x10007008, 0x1971)
        v = self.read32(0x10007000)[0]
        v &= ~0x10
        v &= ~0x48
        v |= 0x22000014
        self.write32(0x10007000, v)
        time.sleep(0.4)
        self.write32(0x10007014, 0x1209)

    def crash_preloader(self):
        """Call upload_da with base 0 and 0x100 zero bytes (presumably to crash the preloader)."""
        self.upload_da(0, b'\x00' * 0x100)


